import numpy as np
from scipy.optimize import minimize, LinearConstraint
import matplotlib.pyplot as plt
import matplotlib.cm as cm

def compute_direction_p(mono_lang_dict, lang_list, D):
    """
    Compute the normalized effective allocation proportions (p_i) for a given subset of languages.

    Each language gets an unnormalized weight

        w_i = (B_i * b_i / D^{b_i})^{1 / (b_i + 1)},   with b_i = -beta_i,

    and the weights are normalized so that sum(p_i) = 1.

    Parameters
    ----------
    mono_lang_dict : dict
        Dictionary mapping each language code (str) to its monolingual scaling-law parameters:
        mono_lang_dict = {
            language_code: [B_i, beta_i, E_i],
            ...
        }
        Each parameter set:
            - B_i: Scaling-law coefficient.
            - beta_i: Power-law exponent, stored with its sign. The code uses
              its negation b_i = -beta_i. With positive B_i the weight base
              B_i * b_i must be positive, so beta_i is expected to be negative.
            - E_i: Constant offset term (not used here).

    lang_list : list of str
        List of language codes for which p_i is calculated.

    D : float
        Total training token budget (hyperparameter).

    Returns
    -------
    p_dict : dict
        Dictionary mapping language codes to their computed normalized proportions p_i,
        satisfying sum(p_i) = 1. An empty ``lang_list`` yields an empty dict.

    Raises
    ------
    ValueError
        If any language produces a non-finite or non-positive weight (for
        example a sign-convention mismatch that puts a negative base under a
        fractional power, or b_i == -1). Without this check the result would
        silently contain NaN/inf proportions.
    """
    if len(lang_list) == 0:
        return {}

    # Cast to float so integer-valued parameters do not trigger numpy's
    # integer-power restrictions (e.g. int ** negative int).
    Bs = np.array([mono_lang_dict[lang][0] for lang in lang_list], dtype=float)
    # Negate the stored exponent; the weight formula is written in terms of b_i = -beta_i.
    betas = -np.array([mono_lang_dict[lang][1] for lang in lang_list], dtype=float)

    # w_i = (B_i * b_i / D^{b_i})^{1/(b_i+1)}
    # Invalid inputs are detected explicitly below, so silence numpy's warnings here.
    with np.errstate(invalid='ignore', divide='ignore', over='ignore'):
        w = (Bs * betas / (D ** betas)) ** (1.0 / (betas + 1.0))

    bad = [lang for lang, w_i in zip(lang_list, w) if not (np.isfinite(w_i) and w_i > 0)]
    if bad:
        raise ValueError(
            f"compute_direction_p: non-finite or non-positive weight for languages {bad}; "
            "check the sign convention of beta_i (expected B_i * (-beta_i) > 0 and beta_i != 1)"
        )

    # Normalize weights to obtain proportions p_i
    p = w / w.sum()

    return {lang: float(p_i) for lang, p_i in zip(lang_list, p)}


from scipy.optimize import minimize, LinearConstraint

def optimize_effective_mix(relation_dict, p_dict, D, rho=10.0, method='trust-constr'):
    """
    Optimize raw language proportions r to maximize total effective data.

    Given inter-language mapping parameters and desired proportions p_i,
    maximize the sum of tilde_r_i subject to sum(r_i) = 1 and 0 <= r_i <= 1.
    A soft quadratic penalty keeps the normalized tilde_r proportions close to p_i.

    Parameters
    ----------
    relation_dict : dict
        Nested dictionary capturing inter-language interaction parameters, structured as follows:

        relation_dict = {
            lang_i: {
                'alpha': {
                    lang_j: [a_ij, b_ij],  # Parameters for mapping from language j to language i
                    ...
                },
                'eta': eta_i,              # Language-specific saturation parameter
            },
            ...
        }

        Interpretation of parameters:
            - alpha[lang_j] contains parameters [a_ij, b_ij] used to calculate the interaction term from language j to language i as:
              alpha_ij = a_ij + (b_ij / D)
            - Every lang_i must list an [a_ij, b_ij] pair for every other language in relation_dict.

    p_dict : dict
        Desired effective language proportions {lang: p_i}. It must contain every
        language in relation_dict. Values (int or float) are renormalized to sum to 1.

    D : float
        Total training token budget, used to form alpha_ij = a_ij + b_ij / D.

    rho : float, optional
        Penalty coefficient enforcing closeness of optimized proportions to p_i (default=10.0).

    method : str, optional
        Optimization algorithm to use ('trust-constr' or 'SLSQP'), default='trust-constr'.
        Tolerance options are chosen to match the selected solver.

    Returns
    -------
    dict containing:
        'r': dict
            Optimized raw language proportions {lang: r_i}, satisfying sum(r_i) = 1
            up to solver tolerance.
        'tilde_r': dict
            Optimized effective proportions {lang: tilde_r_i}.

    Raises
    ------
    ValueError
        If relation_dict is empty.

    Warns
    -----
    RuntimeWarning
        If the underlying scipy solver reports that it did not succeed.
    """
    import warnings

    langs = list(relation_dict.keys())
    n = len(langs)
    if n == 0:
        raise ValueError("optimize_effective_mix: relation_dict must contain at least one language")

    # Target proportion vector. Force float dtype: with integer values in
    # p_dict the in-place division below would otherwise fail on an int array.
    p = np.array([p_dict[lg] for lg in langs], dtype=float)
    p /= p.sum()  # Ensure normalization

    # Expand interaction parameters into arrays
    alpha_mat = np.zeros((n, n))
    eta_vec = np.zeros(n)

    for i, li in enumerate(langs):
        eta_vec[i] = relation_dict[li]['eta']
        for j, lj in enumerate(langs):
            if lj == li:
                continue  # diagonal stays 0: no self-transfer term
            a, b = relation_dict[li]['alpha'][lj]
            alpha_mat[i, j] = a + (b / D)

    def compute_tilde(r):
        """
        Compute tilde_r_i for each language:

        tilde_r_i = [sum_{j != i} alpha_{i,j} * r_j] * [1 - exp(-eta_i * r_i)] + r_i
        """
        # alpha_mat has a zero diagonal, so the matrix-vector product already
        # excludes the j == i term.
        cross = alpha_mat @ r
        return cross * (1.0 - np.exp(-eta_vec * r)) + r

    def objective(r):
        """
        Objective to maximize total tilde_r while penalizing deviation from desired proportions p_i.
        """
        tilde = compute_tilde(r)
        total = tilde.sum()
        ratio = tilde / total
        penalty = rho * np.sum((ratio - p) ** 2)
        return -total + penalty

    # Constraint: sum(r_i) = 1
    lc = LinearConstraint(np.ones(n), 1, 1)

    # Initial values: Start from target proportions
    x0 = p.copy()
    bounds = [(0, 1)] * n

    # Each scipy solver accepts its own tolerance options. Unknown options are
    # ignored with an OptimizeWarning, so pick the right set per method.
    method_name = method.lower() if isinstance(method, str) else None
    if method_name == 'slsqp':
        options = {'ftol': 1e-12, 'maxiter': 500}
    else:
        options = {'gtol': 1e-12, 'xtol': 1e-12, 'maxiter': 500}

    sol = minimize(objective, x0,
                   method=method,
                   bounds=bounds,
                   constraints=[lc],
                   options=options)

    if not sol.success:
        warnings.warn(
            f"optimize_effective_mix: optimizer did not report success ({sol.message})",
            RuntimeWarning,
            stacklevel=2,
        )

    r_opt = sol.x
    tilde_opt = compute_tilde(r_opt)

    # Return results as dictionaries
    r_dict = {lg: float(v) for lg, v in zip(langs, r_opt)}
    tilde_dict = {lg: float(v) for lg, v in zip(langs, tilde_opt)}

    return {'r': r_dict, 'tilde_r': tilde_dict}
